import torch
import os
import numpy as np
import torch.nn.functional as F

from core.rl.agents.ac_agent import ACAgent
from core.modules.subnetworks import Decoder
from core.modules.layers import LayerBuilderParams
from core.utils.general_utils import ParamDict, map_dict, AttrDict
from core.utils.pytorch_utils import ten2ar, avg_grad_norm, TensorModule, check_shape, map2torch, map2np
from core.rl.utils.mpi import sync_networks


class SpritesExpertAgent(ACAgent):
    """Implements Sprites Expert agent.

    The agent is scripted rather than learned: `_act` returns a scaled, clipped
    vector pointing from the agent position to the target position, and `update`
    performs no parameter update. Experience is still stored in a replay buffer.
    """
    def __init__(self, config):
        """Builds the agent and its replay buffer from `config`.

        :arg config: hyperparameters that overwrite the defaults from `_default_hparams`;
                     `replay` must be a callable that accepts `replay_params`.
        """
        super().__init__(config)
        self._hp = self._default_hparams().overwrite(config)
        # build replay buffer
        self.replay_buffer = self._hp.replay(self._hp.replay_params)
        # step counter exposed via `schedule_steps`; never incremented here since `update` is a no-op
        self._update_steps = 0

    def _default_hparams(self):
        """Returns the parent's default hyperparameters extended with the ones declared here."""
        default_dict = ParamDict({
            'replay': None,           # replay buffer class
            'replay_params': None,    # parameters for replay buffer
            'reward_scale': 1.0,      # reward scale (not read anywhere in this class)
            'max_speed': 0.05,        # displacement that maps to a full-magnitude (+/-1) action
        })
        return super()._default_hparams().overwrite(default_dict)

    def update(self, experience_batch):
        """No update: the expert is scripted, so the batch is ignored.

        :arg experience_batch: unused.
        :return: empty AttrDict (no training info to report).
        """
        return AttrDict()

    def _act(self, obs):
        """Computes the expert action that moves the agent towards the target.

        :arg obs: indexable observation; the first two entries are read as the agent
                  position and all remaining entries as the target position.
                  NOTE(review): this presumably expects exactly 4 entries
                  (agent xy, target xy) -- confirm against the environment.
        :return: AttrDict with `action` = (target - agent) / max_speed, clipped to [-1, 1].
        """
        agent = obs[:2]
        target = obs[2:]
        vel = target - agent
        # normalize by max_speed so that displacements >= max_speed saturate at +/-1
        return AttrDict(action=np.clip(vel / self._hp.max_speed, -1, 1))

    def add_experience(self, experience_batch):
        """Adds experience to the replay buffer and updates the observation normalizer.

        :arg experience_batch: batch of experience; its `observation` field is fed to the normalizer.
        """
        self.replay_buffer.append(experience_batch)
        self._obs_normalizer.update(experience_batch.observation)

    def _run_policy(self, obs):
        """Runs the policy on `obs`; allows child classes to post-process policy outputs."""
        return self.policy(obs)

    def _prep_action(self, action):
        """Preprocessing of action in case of discrete action space.

        :arg action: action tensor.
        :return: float tensor; a 1-dim input gets a trailing singleton dimension appended.
        """
        if len(action.shape) == 1:
            action = action[:, None]  # unsqueeze for single-dim action spaces
        return action.float()

    def state_dict(self, *args, **kwargs):
        """Returns the module state dict.

        All arguments are forwarded to the parent implementation. This matters when the
        agent is a submodule of another module: the parent then passes `destination` /
        `prefix` / `keep_vars`, and dropping them would leave this agent's entries out
        of the parent's state dict.
        """
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Loads the module state; delegates to the parent implementation."""
        super().load_state_dict(state_dict, *args, **kwargs)

    def save_state(self, save_dir):
        """Saves the replay buffer to `<save_dir>/replay`."""
        self.replay_buffer.save(os.path.join(save_dir, 'replay'))

    def load_state(self, save_dir):
        """Loads the replay buffer from `<save_dir>/replay`."""
        self.replay_buffer.load(os.path.join(save_dir, 'replay'))

    @property
    def schedule_steps(self):
        """Number of update steps used for optional variable schedules (stays 0 for this agent)."""
        return self._update_steps

